// SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
// Copyright (c) 2020-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
#pragma once

#include <cstdint>
#include <cstring>
#include <functional>
#include <map>
#include <memory>
#include <mutex>
#include <shared_mutex>  // NOLINT(build/include_order)
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>

#include "common/unique_index_map.hpp"
#include "gems/core/math/pose2.hpp"
#include "gems/core/math/pose3.hpp"
#include "gems/gxf_helpers/expected_macro_gxf.hpp"
#include "gems/pose_tree/pose_tree_edge_history.hpp"
#include "gxf/core/component.hpp"
#include "gxf/std/gems/suballocators/first_fit_allocator.hpp"

namespace nvidia {
namespace isaac {

// A temporal pose tree to store relative coordinate system transformations over time.
// This implementation does not support multiple paths between the same coordinate systems at a
// given time. It does however allow disconnecting an edge and creating a new connection using a
// different path. It also allows for multiple "roots". In fact the transformation relationships
// form an acyclic, bi-directional, not necessarily fully-connected graph.
// This PoseTree assigns a different version id to each operation that affects it, and this version
// can be used to make a query ignore later changes made to the tree.
class PoseTree : public gxf::Component {
 public:
  // Error codes used by this class.
  enum class Error {
    // kInvalidArgument is returned when a function is called with an argument that does not make
    // sense, such as a negative number of frames.
    kInvalidArgument = 0,
    // kOutOfMemory is returned if `init` failed to allocate the requested memory, or if an
    // edge/frame can't be added because we ran out of the pre-allocated memory.
    kOutOfMemory = 1,
    // kFrameNotFound is returned if a query is made with a frame uid that does not match any
    // existing frame.
    kFrameNotFound = 2,
    // kAlreadyExists is returned if a frame or an edge that already exists is added.
    kAlreadyExists = 3,
    // kCyclingDependency is returned if a pose is added that would create a cycle in the PoseTree
    // structure.
    kCyclingDependency = 4,
    // kFramesNotLinked is returned if a query is made between two frames that are not connected,
    // or if we attempt to disconnect/delete an edge that does not exist.
    kFramesNotLinked = 5,
    // kPoseOutOfOrder is returned if a query is made to update the tree in the past. For example
    // if we try to disconnect or update a pose at a time older than the latest update on this edge.
    kPoseOutOfOrder = 6,
    // kLogicError is used whenever an error that should not have happened happened. This should
    // never happen; it only exists to report such a case instead of crashing/asserting.
    kLogicError = 7,
  };

  // The maximum size for the name of a frame. An additional '\0' is added at the end to make it
  // 64 characters long.
  static constexpr int32_t kFrameNameMaximumLength = 63;

  // Auto generated frame names will start with this prefix, followed by the uid of the frame.
  static constexpr char const* kAutoGeneratedFrameNamePrefix = "_frame_";

  // Expected type used by this class.
  template <typename T>
  using Expected = nvidia::Expected<T, Error>;
  // Unexpected type used by this class.
  using Unexpected = nvidia::Unexpected<Error>;

  // Type used to uniquely identify a frame.
  using frame_t = uint64_t;
  // Type used for versioning the PoseTree.
  using version_t = uint64_t;
  // Type used as a key for the PoseTreeEdgeHistory map.
  using history_t = uint64_t;

  // Type for callback functions that are called every time a frame is created.
  using CreateFrameCallback = std::function<void(frame_t frame)>;
  // Type for callback functions that are called every time an edge is set.
  using SetEdgeCallback = std::function<void(frame_t lhs, frame_t rhs, double time,
                                             const ::nvidia::isaac::Pose3d& lhs_T_rhs)>;

  // Allocates space for a given number of total frames and total number of edges.
  // Total amount of memory required is approximately:
  //  number_frames * 128 + number_edges * 64 + history_length * 72.
  // `default_number_edges` and `default_history_length` are the defaults used when a frame/edge is
  // created without an explicit hint. `edges_chunk_size` / `history_chunk_size` presumably tune the
  // internal first-fit sub-allocators (TODO confirm in the implementation).
  Expected<void> init(int32_t number_frames, int32_t number_edges, int32_t history_length,
                      int32_t default_number_edges, int32_t default_history_length,
                      int32_t edges_chunk_size, int32_t history_chunk_size);
  // Releases the resources acquired by `init`.
  void deinit();

  // Returns the current PoseTree version.
  version_t getPoseTreeVersion() const;

  // Creates a new frame in the PoseTree. An optional name may be given to give a human-readable
  // name to the frame. The name is a null-terminated string with at most 63 characters. User
  // defined names cannot start with "_", which is reserved for auto generated names such as
  // "_frame_i", where i is the uid of the frame. A hint on the maximum number of edges this frame
  // will be connected to can be provided. Returns the frame id.
  Expected<frame_t> createFrame(const char* name, int32_t number_edges);
  Expected<frame_t> createFrame(const char* name);
  Expected<frame_t> createFrame(int32_t number_edges);
  Expected<frame_t> createFrame();

  // Finds a frame with the given name, and returns the frame id. If no such frame exists, it
  // returns Error::kFrameNotFound.
  Expected<frame_t> findFrame(const char* name) const;
  // Finds a frame with the given name, and returns the frame id. If no such frame exists, it
  // creates a new frame by calling the equivalent createFrame function.
  Expected<frame_t> findOrCreateFrame(const char* name);
  Expected<frame_t> findOrCreateFrame(const char* name, int32_t number_edges);

  // Creates an edge between the left hand side (lhs) frame and the right hand side (rhs) frame.
  // A hint on the maximum history length needed, and the default access method used to
  // interpolate the edge, can be provided.
  // Upon success, it returns the version id of the change.
  Expected<version_t> createEdges(frame_t lhs, frame_t rhs);
  Expected<version_t> createEdges(frame_t lhs, frame_t rhs, int32_t maximum_length);
  Expected<version_t> createEdges(frame_t lhs, frame_t rhs,
                                  PoseTreeEdgeHistory::AccessMethod method);
  Expected<version_t> createEdges(frame_t lhs, frame_t rhs, int32_t maximum_length,
                                  PoseTreeEdgeHistory::AccessMethod method);

  // Deletes a frame in the PoseTree and all its relations to other frames and frees its memory.
  // This action permanently erases the history information.
  // Upon success, it returns the version id of the change (however queries made with a previous
  // version will also consider the frame as deleted).
  Expected<version_t> deleteFrame(frame_t uid);
  // Deletes an edge and frees the memory.
  // This action permanently erases the history.
  // Upon success, it returns the version id of the change (however queries made with a previous
  // version will also consider the edge as deleted).
  Expected<version_t> deleteEdge(frame_t lhs, frame_t rhs);

  // Disconnects a frame from all the others starting at a given time.
  // Upon success, it returns the version id of the change.
  Expected<version_t> disconnectFrame(frame_t uid, double time);
  // Disconnects an edge starting at a given time.
  // Upon success, it returns the version id of the change.
  Expected<version_t> disconnectEdge(frame_t lhs, frame_t rhs, double time);

  // Disable all the implicit casts (to make sure to catch a call with the wrong type for the time).
  // The return type mirrors the real overloads above.
  template <class ... Args> Expected<version_t>
  disconnectFrame(Args&&... args) = delete;
  template <class ... Args> Expected<version_t>
  disconnectEdge(Args&&... args) = delete;

  // Gets the name of a frame.
  Expected<const char*> getFrameName(frame_t uid) const;

  // Gets the latest pose between two frames as well as the time of that pose.
  // The two frames need to be directly linked.
  Expected<std::pair<::nvidia::isaac::Pose3d, double>> getLatest(frame_t lhs, frame_t rhs) const;
  Expected<std::pair<::nvidia::isaac::Pose3d, double>>
  getLatest(const char* lhs, const char* rhs) const;

  // Gets the pose lhs_T_rhs between two frames in the PoseTree at the given time. If the poses are
  // not connected exactly at the given time, the indicated method is used to interpolate the data.
  // When a version is given, changes made to the tree after that version are ignored.
  Expected<::nvidia::isaac::Pose3d> get(frame_t lhs, frame_t rhs, double time,
                                        PoseTreeEdgeHistory::AccessMethod method,
                                        version_t version) const;
  Expected<::nvidia::isaac::Pose3d> get(frame_t lhs,
                                        frame_t rhs,
                                        double time,
                                        version_t version) const;
  Expected<::nvidia::isaac::Pose3d> get(frame_t lhs, frame_t rhs, double time,
                                        PoseTreeEdgeHistory::AccessMethod method) const;
  Expected<::nvidia::isaac::Pose3d> get(frame_t lhs, frame_t rhs, double time) const;

  // Same as above, but using the name as interface.
  Expected<::nvidia::isaac::Pose3d> get(const char* lhs, const char* rhs, double time,
                                        PoseTreeEdgeHistory::AccessMethod method,
                                        version_t version) const;
  Expected<::nvidia::isaac::Pose3d> get(const char* lhs, const char* rhs, double time,
                                        version_t version) const;
  Expected<::nvidia::isaac::Pose3d> get(const char* lhs, const char* rhs, double time,
                                        PoseTreeEdgeHistory::AccessMethod method) const;
  Expected<::nvidia::isaac::Pose3d> get(const char* lhs, const char* rhs, double time) const;

  // Disable all the implicit casts (to make sure to catch a call with the wrong type for the time).
  template <class ... Args> Expected<::nvidia::isaac::Pose3d> get(Args&&... args) const = delete;

  // Helper function to get a Pose2d instead of Pose3d (forwards to `get` and projects the result
  // with Pose3d::toPose2XY).
  template <class ... Args>
  Expected<::nvidia::isaac::Pose2d> getPose2XY(Args&&... args) const {
    return get(std::forward<Args>(args)...).map([](const ::nvidia::isaac::Pose3d& pose_3d) {
      return pose_3d.toPose2XY();
    });
  }

  // The last time at which the pose between two frames is specified.
  // TODO(bbutin): Implement the following methods.
  // Expected<timed_pose_t> latest(frame_t lhs, frame_t rhs, version_t version) const;
  // Expected<timed_pose_t> latest(frame_t lhs, frame_t rhs) const;

  // Sets the pose between two frames in the PoseTree. Note that poses can not be changed
  // retrospectively. Thus for example once the pose at time t=2.0 is set it is no longer allowed
  // to set the pose for time t <= 2.0. It is not allowed to form cycles. Frames are implicitly
  // linked. If more than the maximum number of allowed poses are set the oldest pose is deleted.
  // Upon success, it returns the version id of the change.
  Expected<version_t> set(frame_t lhs, frame_t rhs, double time,
                          const ::nvidia::isaac::Pose3d& lhs_T_rhs);
  // Same as above, but using the name as interface.
  Expected<version_t> set(const char* lhs, const char* rhs, double time,
                          const ::nvidia::isaac::Pose3d& lhs_T_rhs);

  // Helper function to set a Pose2d instead of Pose3d
  Expected<version_t> set(frame_t lhs, frame_t rhs, double time,
                          const ::nvidia::isaac::Pose2d& lhs_T_rhs) {
    return set(lhs, rhs, time, ::nvidia::isaac::Pose3d::FromPose2XY(lhs_T_rhs));
  }
  // Same as above, but using the name as interface.
  Expected<version_t> set(const char* lhs, const char* rhs, double time,
                          const ::nvidia::isaac::Pose2d& lhs_T_rhs) {
    return set(lhs, rhs, time, ::nvidia::isaac::Pose3d::FromPose2XY(lhs_T_rhs));
  }

  // Disable all the implicit casts (to make sure to catch a call with the wrong type for the time).
  // First we define a function that will match all the calls with double and forward it to the
  // appropriate function. This is needed in case the Pose is provided by value.
  // This template is only selected when a non-template overload would need a conversion on the
  // frames (e.g. `int` uids or `char*` names). Forwarding the arguments unchanged would therefore
  // resolve back to this very template and recurse forever, so the frames are converted
  // explicitly to the parameter types of the non-template overloads.
  template <typename Frame, typename Pose>
  Expected<version_t> set(Frame lhs, Frame rhs, double time, const Pose lhs_T_rhs) {
    // Any other pose type would re-enter this template instead of a non-template overload.
    static_assert(std::is_same_v<Pose, ::nvidia::isaac::Pose3d> ||
                      std::is_same_v<Pose, ::nvidia::isaac::Pose2d>,
                  "PoseTree::set expects a Pose3d or a Pose2d");
    if constexpr (std::is_integral_v<Frame>) {
      // Frames given by uid (any integral type).
      return set(static_cast<frame_t>(lhs), static_cast<frame_t>(rhs), time, lhs_T_rhs);
    } else {
      // Frames given by name.
      static_assert(std::is_convertible_v<Frame, const char*>,
                    "PoseTree::set expects frames given as frame_t uids or const char* names");
      return set(static_cast<const char*>(lhs), static_cast<const char*>(rhs), time, lhs_T_rhs);
    }
  }
  // Then we disable all the calls not made with double.
  template <typename Frame, typename T, typename Pose>
  Expected<version_t> set(Frame lhs, Frame rhs, T time, const Pose& lhs_T_rhs) = delete;

  // Get list of edges (pairs of frame uids).
  std::vector<std::pair<frame_t, frame_t>> edges() const;

  // Registers a callback function that is called every time a frame is created. The callback is
  // only registered if no callback already exists for the provided component id `cid` (see the
  // implementation for the error returned otherwise).
  Expected<void> addCreateFrameCallback(gxf_uid_t cid, CreateFrameCallback callback);

  // Deregisters the callback function called every time a frame is created. If there is no
  // callback function for the component id `cid`, this function returns an error.
  Expected<void> removeCreateFrameCallback(gxf_uid_t cid);

  // Registers a callback function that is called every time an edge is set. The callback is only
  // registered if no callback already exists for the provided component id `cid` (see the
  // implementation for the error returned otherwise).
  Expected<void> addSetEdgeCallback(gxf_uid_t cid, SetEdgeCallback callback);

  // Deregisters the callback function called every time an edge is set. If there is no callback
  // function for the component id `cid`, this function returns an error.
  Expected<void> removeSetEdgeCallback(gxf_uid_t cid);

  // Helper function to transform an error code into a human readable error.
  static const char* ErrorToStr(Error error);

 private:
  // Helper structure that stores the information about a frame.
  struct FrameInfo {
    // Array containing the list of edges (as history uids).
    history_t* history;
    // Current number of edges
    int32_t number_edges;
    // Maximum number of edges allowed
    int32_t maximum_number_edges;
    // Name of the frame. It has to be null terminated, so it can hold at most 63 characters.
    char name[kFrameNameMaximumLength + 1];
    // Hint to quickly find a path:
    // Store the distance from the node to the root (== 0 if this frame is the root)
    int32_t distance_to_root;
    // Frame to follow to reach the root
    frame_t node_to_root;
    // Uid of the root frame (part of the path-finding hint)
    frame_t root;
    // Helper id used to compute the path between two nodes
    mutable version_t hint_version;
    // Helper to memorize the path we took during the dfs
    mutable frame_t dfs_link;
    // Unique id of the frame.
    frame_t uid;
  };
  // Helper class to compare two const char* (lexicographic order via strcmp).
  // TODO(ben): Remove once we get rid of map
  struct CharMapCompare {
    bool operator()(const char* lhs, const char* rhs) const {
      return std::strcmp(lhs, rhs) < 0;
    }
  };

  // Implementation of findOrCreateFrame
  Expected<frame_t> findOrCreateFrameImpl(const char* name, int32_t number_edges);
  // Implementation of findFrame
  Expected<frame_t> findFrameImpl(const char* name) const;
  // Implementation of createFrame
  Expected<frame_t> createFrameImpl(const char* name, int32_t number_edges);
  // Implementation of createEdges
  Expected<version_t> createEdgesImpl(frame_t lhs, frame_t rhs, int32_t maximum_length,
                                      PoseTreeEdgeHistory::AccessMethod method);
  // Implementation of deleteEdge
  Expected<version_t> deleteEdgeImpl(frame_t lhs, frame_t rhs, version_t version);
  // Update the path to the root for a given connected component starting from the given node.
  Expected<void> updateRoot(frame_t root);
  // Implementation of get using the pre-computed path to the root as a hint. If it fails, it falls
  // back to getDfsImpl.
  Expected<::nvidia::isaac::Pose3d> getImpl(frame_t lhs, frame_t rhs, double time,
                                            PoseTreeEdgeHistory::AccessMethod method,
                                            version_t version) const;
  // Implementation of get that does a dfs to see if a path exists at a given time.
  Expected<::nvidia::isaac::Pose3d> getDfsImpl(frame_t lhs, frame_t rhs, double time,
                                               PoseTreeEdgeHistory::AccessMethod method,
                                               version_t version) const;

  // Lock to protect access to the members below.
  mutable std::shared_timed_mutex mutex_;
  // Lock to protect access to getDfsImpl. This function is rarely called, but can be called while
  // mutex_ is locked in read access, and getDfsImpl is modifying the dfs_link of some frames as
  // well as using the frames_stack_. We need a special protection for this function while not
  // blocking all the concurrent reads which most likely won't call it.
  mutable std::mutex dfs_mutex_;

  // Lock to protect create_frame_callbacks_
  mutable std::shared_timed_mutex create_frame_callbacks_mutex_;
  // Callback functions for the create frame operation.
  // TODO(dbhaskara): Replace with UniqueIndexMap once it supports iteration through all elements
  std::unordered_map<gxf_uid_t, CreateFrameCallback> create_frame_callbacks_;

  // Lock to protect set_edge_callbacks_
  mutable std::shared_timed_mutex set_edge_callbacks_mutex_;
  // Callback functions for the set edge operation.
  // TODO(dbhaskara): Replace with UniqueIndexMap once it supports iteration through all elements
  std::unordered_map<gxf_uid_t, SetEdgeCallback> set_edge_callbacks_;

  // Mapping from an edge (pair of frame uids) to the uid of its history.
  // TODO(ben): We need to get rid of std::map
  std::map<std::pair<frame_t, frame_t>, history_t> edges_map_;

  // Mapping from a frame name to its uid.
  // TODO(ben): We need to get rid of std::map, but for now UniqueIndexMap does not support
  // iterating through all the elements.
  std::map<const char*, frame_t, CharMapCompare> name_to_uid_map_;

  // Store the list of the current frames of the PoseTree.
  UniqueIndexMap<FrameInfo> frame_map_;

  // Used to implement a dfs.
  std::unique_ptr<frame_t[]> frames_stack_;

  // Store the list of PoseTreeEdgeHistory used by the frames. Each PoseTreeEdgeHistory corresponds
  // to a bi-directional edge.
  UniqueIndexMap<PoseTreeEdgeHistory> histories_map_;

  // Helper to `allocate` an array of PoseTreeEdgeHistory (storing only the uid).
  gxf::FirstFitAllocator<history_t> histories_management_;

  // Helper to `allocate` an array of TimedPose.
  gxf::FirstFitAllocator<PoseTreeEdgeHistory::TimedPose> poses_management_;

  // Current version of the PoseTree.
  version_t version_;
  // Version of the hint. Mostly used to know if a node in the stack has been processed already.
  mutable version_t hint_version_;

  // Default maximum number of edges a given frame can have
  int32_t default_number_edges_;
  // Default length of the history used by an edge.
  int32_t default_history_length_;
};

}  // namespace isaac
}  // namespace nvidia

// We configure the expected macro to work with PoseTree::Error. The configuration has to happen in
// exactly this namespace.
namespace nvidia::expected_macro {
template <>
struct IsStatus<::nvidia::isaac::PoseTree::Error> : std::true_type {};
}  // namespace nvidia::expected_macro
